{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "95b08e64",
   "metadata": {},
   "source": [
    "# Filters\n",
    "\n",
    "Flax NNX uses [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) extensively as a way to create [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) groups in APIs, such as [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`nnx.state()`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.state), and many of the [Flax NNX transformations (transforms)](https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html).\n",
    "\n",
    "In this guide you will learn how to:\n",
    "\n",
    "* Use [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) to group Flax NNX variables and states into subgroups;\n",
    "* Understand relationships between types, such as [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) or [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat), and [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html);\n",
    "* Express your `Filter`s flexibly with [`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) language.\n",
    "\n",
    "In the following example [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) and [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat) are used as `Filter`s to split the model into two groups: one with the parameters and the other with the batch statistics:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "45485345",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "params = State({\n",
      "  'a': Param(\n",
      "    value=0\n",
      "  )\n",
      "})\n",
      "batch_stats = State({\n",
      "  'b': BatchStat(\n",
      "    value=True\n",
      "  )\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "from flax import nnx\n",
    "\n",
    "class Foo(nnx.Module):\n",
    "  def __init__(self):\n",
    "    self.a = nnx.Param(0)\n",
    "    self.b = nnx.BatchStat(True)\n",
    "\n",
    "foo = Foo()\n",
    "\n",
    "graphdef, params, batch_stats = nnx.split(foo, nnx.Param, nnx.BatchStat)\n",
    "\n",
    "print(f'{params = }')\n",
    "print(f'{batch_stats = }')"
   ]
  },
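  {
   "cell_type": "markdown",
   "id": "3c9e51b7",
   "metadata": {},
   "source": [
    "The same `Filter`s can be passed to `nnx.state()`, which returns only the selected `nnx.State` groups and no `GraphDef`. The following is a minimal sketch rather than a definitive recipe; it redefines the `Foo` module from above so that it stands alone:\n",
    "\n",
    "```python\n",
    "from flax import nnx\n",
    "\n",
    "class Foo(nnx.Module):\n",
    "  def __init__(self):\n",
    "    self.a = nnx.Param(0)\n",
    "    self.b = nnx.BatchStat(True)\n",
    "\n",
    "foo = Foo()\n",
    "\n",
    "# A single Filter returns a single State with the matching variables.\n",
    "params = nnx.state(foo, nnx.Param)\n",
    "\n",
    "# Several Filters return one State per Filter, in the same order.\n",
    "params_again, batch_stats = nnx.state(foo, nnx.Param, nnx.BatchStat)\n",
    "\n",
    "print(f'{params = }')\n",
    "print(f'{batch_stats = }')\n",
    "```"
   ]
  },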
  {
   "cell_type": "markdown",
   "id": "8f77e99a",
   "metadata": {},
   "source": [
    "Let's dive deeper into `Filter`s."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0413d64",
   "metadata": {},
   "source": [
    "## The `Filter` Protocol\n",
    "\n",
    "In general, Flax `Filter`s are predicate functions of the form:\n",
    "\n",
    "```python\n",
    "\n",
    "(path: tuple[Key, ...], value: Any) -> bool\n",
    "\n",
    "```\n",
    "\n",
    "where:\n",
    "\n",
    "- `Key` is a hashable and comparable type;\n",
    "- `path` is a tuple of `Key`s representing the path to the value in a nested structure; and\n",
    "- `value` is the value at the path.\n",
    "\n",
    "The function returns `True` if the value should be included in the group, and `False` otherwise.\n",
    "\n",
    "Types are not functions of this form. They are treated as `Filter`s because, as you will learn in the next section, types and some other literals are converted to _predicates_. For example, [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) is roughly converted to a predicate like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "30f4c868",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "is_param((), nnx.Param(0)) = True\n"
     ]
    }
   ],
   "source": [
    "def is_param(path, value) -> bool:\n",
    "  return isinstance(value, nnx.Param)\n",
    "\n",
    "print(f'{is_param((), nnx.Param(0)) = }')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8a2641e",
   "metadata": {},
   "source": [
    "Such function matches any value that is an instance of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param). Internally Flax NNX uses `OfType` which defines a callable of this form for a given type:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b3095221",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "is_param((), nnx.Param(0)) = True\n"
     ]
    }
   ],
   "source": [
    "is_param = nnx.OfType(nnx.Param)\n",
    "\n",
    "print(f'{is_param((), nnx.Param(0)) = }')"
   ]
  },
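  {
   "cell_type": "markdown",
   "id": "5d2a8e04",
   "metadata": {},
   "source": [
    "Because a `Filter` is just a predicate over `(path, value)`, you can also write one by hand and select by `path` instead of by type. The sketch below assumes that plain callables are accepted wherever a `Filter` is expected; `path_ends_with_a` is a hypothetical helper introduced only for illustration, and `Foo` is repeated so that the snippet stands alone:\n",
    "\n",
    "```python\n",
    "from flax import nnx\n",
    "\n",
    "class Foo(nnx.Module):\n",
    "  def __init__(self):\n",
    "    self.a = nnx.Param(0)\n",
    "    self.b = nnx.BatchStat(True)\n",
    "\n",
    "# Hypothetical helper: matches any value whose last path key is 'a'.\n",
    "def path_ends_with_a(path, value) -> bool:\n",
    "  return len(path) > 0 and path[-1] == 'a'\n",
    "\n",
    "# The predicate can be called directly, like `is_param` above.\n",
    "direct_match = path_ends_with_a(('a',), nnx.Param(0))\n",
    "\n",
    "# It can also be used as a Filter to select part of a module's state.\n",
    "foo = Foo()\n",
    "a_state = nnx.state(foo, path_ends_with_a)\n",
    "\n",
    "print(f'{direct_match = }')\n",
    "print(f'{a_state = }')\n",
    "```\n",
    "\n",
    "Here the `value` argument is ignored, so the selection depends only on where the variable lives in the module."
   ]
  },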
  {
   "cell_type": "markdown",
   "id": "87c06e39",
   "metadata": {},
   "source": [
    "## The `Filter` DSL\n",
    "\n",
    "Flax NNX exposes a small domain specific language ([DSL](https://en.wikipedia.org/wiki/Domain-specific_language)), formalized as the [`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) type. This means users don't have to create functions like in the previous section.\n",
    "\n",
    "Here is a list of all the callable `Filter`s included in Flax NNX, and their corresponding DSL literals (when available):\n",
    "\n",
    "\n",
    "| Literal | Callable | Description |\n",
    "|--------|----------------------|-------------|\n",
    "| `...` or `True` | `Everything()` | Matches all values |\n",
    "| `None` or `False` | `Nothing()` | Matches no values |\n",
    "| `type` | `OfType(type)` | Matches values that are instances of `type` or have a `type` attribute that is an instance of `type` |\n",
    "| | `PathContains(key)` | Matches values that have an associated `path` that contains the given `key` |\n",
    "| `'{filter}'` <span style=\"color:gray\">str</span> | `WithTag('{filter}')` | Matches values that have string `tag` attribute equal to `'{filter}'`. Used by `RngKey` and `RngCount`. |\n",
    "| `(*filters)` <span style=\"color:gray\">tuple</span> or `[*filters]` <span style=\"color:gray\">list</span> | `Any(*filters)` | Matches values that match any of the inner `filters` |\n",
    "| | `All(*filters)` | Matches values that match all of the inner `filters` |\n",
    "| | `Not(filter)` | Matches values that do not match the inner `filter` |\n",
    "\n",
    "\n",
    "Let's check out the DSL in action by using [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap) as an example. Consider the following:\n",
    "\n",
    "1) You want to vectorize all parameters;\n",
    "2) Apply `'dropout'` `Rng(Keys|Counts)` on the `0`th axis; and\n",
    "3) Broadcast the rest.\n",
    "\n",
    "To do this, you can use the following `Filter`s to define a `nnx.StateAxes` object that you can pass to `nnx.vmap`'s `in_axes` to specify how the `model`'s various sub-states should be vectorized:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d38b7694",
   "metadata": {},
   "outputs": [],
   "source": [
    "state_axes = nnx.StateAxes({(nnx.Param, 'dropout'): 0, ...: None})\n",
    "\n",
    "@nnx.vmap(in_axes=(state_axes, 0))\n",
    "def forward(model, x):\n",
    "  ..."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd60f0e1",
   "metadata": {},
   "source": [
    "Here `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...` expands to `Everything()`.\n",
    "\n",
    "If you wish to manually convert literal into a predicate, you can use [`nnx.filterlib.to_predicate`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html#flax.nnx.filterlib.to_predicate):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7e065fa9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "is_param = OfType(<class 'flax.nnx.variablelib.Param'>)\n",
      "everything = Everything()\n",
      "nothing = Nothing()\n",
      "params_or_dropout = Any(OfType(<class 'flax.nnx.variablelib.Param'>), WithTag('dropout'))\n"
     ]
    }
   ],
   "source": [
    "is_param = nnx.filterlib.to_predicate(nnx.Param)\n",
    "everything = nnx.filterlib.to_predicate(...)\n",
    "nothing = nnx.filterlib.to_predicate(False)\n",
    "params_or_dropout = nnx.filterlib.to_predicate((nnx.Param, 'dropout'))\n",
    "\n",
    "print(f'{is_param = }')\n",
    "print(f'{everything = }')\n",
    "print(f'{nothing = }')\n",
    "print(f'{params_or_dropout = }')"
   ]
  },
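  {
   "cell_type": "markdown",
   "id": "9b41c6fa",
   "metadata": {},
   "source": [
    "The callables without a literal form, such as `PathContains` and `All`, can be mixed freely with literals. As a hedged sketch, the following selects only the parameters whose path contains `'bias'`; the `Block` module and its attribute names are made up for this example:\n",
    "\n",
    "```python\n",
    "from flax import nnx\n",
    "\n",
    "# Hypothetical module used only for this example.\n",
    "class Block(nnx.Module):\n",
    "  def __init__(self):\n",
    "    self.kernel = nnx.Param(1.0)\n",
    "    self.bias = nnx.Param(0.0)\n",
    "    self.mean = nnx.BatchStat(0.0)\n",
    "\n",
    "block = Block()\n",
    "\n",
    "# Match values that are Params AND whose path contains the key 'bias'.\n",
    "bias_params = nnx.All(nnx.Param, nnx.PathContains('bias'))\n",
    "\n",
    "# `...` catches everything else, so every variable lands in a group.\n",
    "graphdef, biases, rest = nnx.split(block, bias_params, ...)\n",
    "\n",
    "print(f'{biases = }')\n",
    "print(f'{rest = }')\n",
    "```\n",
    "\n",
    "Note that `nnx.Param` is passed to `All` as a literal; it is converted to `OfType(nnx.Param)` in the same way as the literals shown above."
   ]
  },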
  {
   "cell_type": "markdown",
   "id": "db9b4cf3",
   "metadata": {},
   "source": [
    "## Grouping `State`s\n",
    "\n",
    "With the knowledge of `Filter`s from previous sections at hand, let's learn how to roughly implement [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split). Here are the key ideas:\n",
    "\n",
    "* Use `nnx.graph.flatten` to get the [`GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef) and [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) representation of the node.\n",
    "* Convert all the `Filter`s to predicates.\n",
    "* Use `State.flat_state` to get the flat representation of the state.\n",
    "* Traverse all the `(path, value)` pairs in the flat state and group them according to the predicates.\n",
    "* Use `State.from_flat_state` to convert the flat states to nested [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "068208fc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "params = State({\n",
      "  'a': Param(\n",
      "    value=0\n",
      "  )\n",
      "})\n",
      "batch_stats = State({\n",
      "  'b': BatchStat(\n",
      "    value=True\n",
      "  )\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "from typing import Any\n",
    "KeyPath = tuple[nnx.graph.Key, ...]\n",
    "\n",
    "def split(node, *filters):\n",
    "  graphdef, state = nnx.graph.flatten(node)\n",
    "  predicates = [nnx.filterlib.to_predicate(f) for f in filters]\n",
    "  flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]\n",
    "\n",
    "  for path, value in state:\n",
    "    for i, predicate in enumerate(predicates):\n",
    "      if predicate(path, value):\n",
    "        flat_states[i][path] = value\n",
    "        break\n",
    "    else:\n",
    "      raise ValueError(f'No filter matched {path = } {value = }')\n",
    "\n",
    "  states: tuple[nnx.GraphState, ...] = tuple(\n",
    "    nnx.State.from_flat_path(flat_state) for flat_state in flat_states\n",
    "  )\n",
    "  return graphdef, *states\n",
    "\n",
    "# Let's test it.\n",
    "foo = Foo()\n",
    "\n",
    "graphdef, params, batch_stats = split(foo, nnx.Param, nnx.BatchStat)\n",
    "\n",
    "print(f'{params = }')\n",
    "print(f'{batch_stats = }')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b3aeac8",
   "metadata": {},
   "source": [
    "**Note:*** It's very important to know that **filtering is order-dependent**. The first `Filter` that matches a value will keep it, and therefore you should place more specific `Filter`s before more general `Filter`s.\n",
    "\n",
    "For example, as demonstrated below, if you:\n",
    "\n",
    "1) Create a `SpecialParam` type that is a subclass of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param), and a `Bar` object (subclassing [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)) that contains both types of parameters; and\n",
    "2) Try to split the [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s before the `SpecialParam`s\n",
    "\n",
    "then all the values will be placed in the [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) group, and the `SpecialParam` group will be empty because all `SpecialParam`s are also [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "014da4d4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "params = State({\n",
      "  'a': Param(\n",
      "    value=0\n",
      "  ),\n",
      "  'b': SpecialParam(\n",
      "    value=0\n",
      "  )\n",
      "})\n",
      "special_params = State({})\n"
     ]
    }
   ],
   "source": [
    "class SpecialParam(nnx.Param):\n",
    "  pass\n",
    "\n",
    "class Bar(nnx.Module):\n",
    "  def __init__(self):\n",
    "    self.a = nnx.Param(0)\n",
    "    self.b = SpecialParam(0)\n",
    "\n",
    "bar = Bar()\n",
    "\n",
    "graphdef, params, special_params = split(bar, nnx.Param, SpecialParam) # wrong!\n",
    "print(f'{params = }')\n",
    "print(f'{special_params = }')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9f0b7b8",
   "metadata": {},
   "source": [
    "And reversing the order will ensure that the `SpecialParam` are captured first:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a2ebf5b2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "params = State({\n",
      "  'a': Param(\n",
      "    value=0\n",
      "  )\n",
      "})\n",
      "special_params = State({\n",
      "  'b': SpecialParam(\n",
      "    value=0\n",
      "  )\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # correct!\n",
    "print(f'{params = }')\n",
    "print(f'{special_params = }')"
   ]
  }
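  ,
  {
   "cell_type": "markdown",
   "id": "e7a3d915",
   "metadata": {},
   "source": [
    "If relying on the order feels fragile, one alternative is to make the `Filter`s mutually exclusive with `All` and `Not`, so that each value matches exactly one of them. This is a sketch, not the only way to do it; `Bar` and `SpecialParam` are repeated here so that the snippet stands alone, and it uses `nnx.split` rather than the simplified `split` from this guide:\n",
    "\n",
    "```python\n",
    "from flax import nnx\n",
    "\n",
    "class SpecialParam(nnx.Param):\n",
    "  pass\n",
    "\n",
    "class Bar(nnx.Module):\n",
    "  def __init__(self):\n",
    "    self.a = nnx.Param(0)\n",
    "    self.b = SpecialParam(0)\n",
    "\n",
    "bar = Bar()\n",
    "\n",
    "# Params that are NOT SpecialParams, so the two filters never overlap.\n",
    "plain_param = nnx.All(nnx.Param, nnx.Not(SpecialParam))\n",
    "\n",
    "# The general filter comes first here, yet `b` still reaches its own group.\n",
    "graphdef, params, special_params = nnx.split(bar, plain_param, SpecialParam)\n",
    "\n",
    "print(f'{params = }')\n",
    "print(f'{special_params = }')\n",
    "```"
   ]
  }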
 ],
 "metadata": {
  "jupytext": {
   "formats": "ipynb,md:myst"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
